{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mchorse/miniconda3/envs/logan/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: '/home/mchorse/sparse_coding_hoagy/output_attn_sweep_tied_attn_l3_r4/_9/learned_dicts.pt'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 45\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mpickle\u001b[39;00m\n\u001b[1;32m      7\u001b[0m \u001b[39m# # Define the autoencoder so pickle knows how to serialize it. \u001b[39;00m\n\u001b[1;32m      8\u001b[0m \u001b[39m# # Later, we should actually save as a state_dict instead of a dumb pickle\u001b[39;00m\n\u001b[1;32m      9\u001b[0m \u001b[39m# class AutoEncoder(nn.Module):\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     43\u001b[0m \u001b[39m#         return next(self.parameters()).device\u001b[39;00m\n\u001b[1;32m     44\u001b[0m \u001b[39m# all_autoencoders = torch.load(\"/home/mchorse/sparse_coding_aidan_new/output_4_rd_deep/_7/learned_dicts.pt\")\u001b[39;00m\n\u001b[0;32m---> 45\u001b[0m all_autoencoders \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mload(\u001b[39m\"\u001b[39;49m\u001b[39m/home/mchorse/sparse_coding_hoagy/output_attn_sweep_tied_attn_l3_r4/_9/learned_dicts.pt\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n\u001b[1;32m     46\u001b[0m num_dictionaries \u001b[39m=\u001b[39m \u001b[39mlen\u001b[39m(all_autoencoders)\n\u001b[1;32m     47\u001b[0m auto_num \u001b[39m=\u001b[39m \u001b[39m9\u001b[39m\n",
      "File \u001b[0;32m~/miniconda3/envs/logan/lib/python3.10/site-packages/torch/serialization.py:791\u001b[0m, in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, weights_only, **pickle_load_args)\u001b[0m\n\u001b[1;32m    788\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m'\u001b[39m\u001b[39mencoding\u001b[39m\u001b[39m'\u001b[39m \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m pickle_load_args\u001b[39m.\u001b[39mkeys():\n\u001b[1;32m    789\u001b[0m     pickle_load_args[\u001b[39m'\u001b[39m\u001b[39mencoding\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m \u001b[39m'\u001b[39m\u001b[39mutf-8\u001b[39m\u001b[39m'\u001b[39m\n\u001b[0;32m--> 791\u001b[0m \u001b[39mwith\u001b[39;00m _open_file_like(f, \u001b[39m'\u001b[39;49m\u001b[39mrb\u001b[39;49m\u001b[39m'\u001b[39;49m) \u001b[39mas\u001b[39;00m opened_file:\n\u001b[1;32m    792\u001b[0m     \u001b[39mif\u001b[39;00m _is_zipfile(opened_file):\n\u001b[1;32m    793\u001b[0m         \u001b[39m# The zipfile reader is going to advance the current file position.\u001b[39;00m\n\u001b[1;32m    794\u001b[0m         \u001b[39m# If we want to actually tail call to torch.jit.load, we need to\u001b[39;00m\n\u001b[1;32m    795\u001b[0m         \u001b[39m# reset back to the original position.\u001b[39;00m\n\u001b[1;32m    796\u001b[0m         orig_position \u001b[39m=\u001b[39m opened_file\u001b[39m.\u001b[39mtell()\n",
      "File \u001b[0;32m~/miniconda3/envs/logan/lib/python3.10/site-packages/torch/serialization.py:271\u001b[0m, in \u001b[0;36m_open_file_like\u001b[0;34m(name_or_buffer, mode)\u001b[0m\n\u001b[1;32m    269\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_open_file_like\u001b[39m(name_or_buffer, mode):\n\u001b[1;32m    270\u001b[0m     \u001b[39mif\u001b[39;00m _is_path(name_or_buffer):\n\u001b[0;32m--> 271\u001b[0m         \u001b[39mreturn\u001b[39;00m _open_file(name_or_buffer, mode)\n\u001b[1;32m    272\u001b[0m     \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    273\u001b[0m         \u001b[39mif\u001b[39;00m \u001b[39m'\u001b[39m\u001b[39mw\u001b[39m\u001b[39m'\u001b[39m \u001b[39min\u001b[39;00m mode:\n",
      "File \u001b[0;32m~/miniconda3/envs/logan/lib/python3.10/site-packages/torch/serialization.py:252\u001b[0m, in \u001b[0;36m_open_file.__init__\u001b[0;34m(self, name, mode)\u001b[0m\n\u001b[1;32m    251\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, name, mode):\n\u001b[0;32m--> 252\u001b[0m     \u001b[39msuper\u001b[39m()\u001b[39m.\u001b[39m\u001b[39m__init__\u001b[39m(\u001b[39mopen\u001b[39;49m(name, mode))\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/home/mchorse/sparse_coding_hoagy/output_attn_sweep_tied_attn_l3_r4/_9/learned_dicts.pt'"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformer_lens import HookedTransformer\n",
    "import numpy as np\n",
    "\n",
    "# --- Configuration: change these to select the dictionary and the activations it was trained on ---\n",
    "autoencoder_filename = \"/mnt/ssd-cluster/bigrun0308/tied_residual_l2_r2/_9/learned_dicts.pt\"\n",
    "auto_num = 9  # index of the smaller autoencoder; auto_num + 1 is used as the larger one\n",
    "# NOTE(review): the filename says \"l2\" but layer = 3 below -- confirm both refer to the same layer.\n",
    "layer = 3\n",
    "setting = \"residual\"\n",
    "# setting = \"attention\"\n",
    "model_name = \"EleutherAI/pythia-70m-deduped\"\n",
    "device = torch.device(\"cuda:3\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# --- Load the dictionaries ---\n",
    "# NOTE: torch.load unpickles arbitrary Python objects; only load trusted checkpoints.\n",
    "all_autoencoders = torch.load(autoencoder_filename)\n",
    "num_dictionaries = len(all_autoencoders)\n",
    "if auto_num + 1 >= num_dictionaries:\n",
    "    raise IndexError(\n",
    "        f\"auto_num={auto_num} needs a larger neighbour at auto_num + 1, \"\n",
    "        f\"but only {num_dictionaries} dictionaries were loaded\"\n",
    "    )\n",
    "autoencoder, hyperparams = all_autoencoders[auto_num]\n",
    "l1_alpha = hyperparams['l1_alpha']\n",
    "autoencoder2, hyperparams2 = all_autoencoders[auto_num + 1]\n",
    "smaller_dict = autoencoder.get_learned_dict()\n",
    "larger_dict = autoencoder2.get_learned_dict()\n",
    "\n",
    "# --- Load the model and choose the hook point that matches `setting` ---\n",
    "model = HookedTransformer.from_pretrained(model_name, device=device)\n",
    "\n",
    "if setting == \"residual\":\n",
    "    cache_name = f\"blocks.{layer}.hook_resid_post\"\n",
    "    neurons = model.cfg.d_model\n",
    "elif setting == \"mlp\":\n",
    "    cache_name = f\"blocks.{layer}.mlp.hook_post\"\n",
    "    neurons = model.cfg.d_mlp\n",
    "elif setting == \"attention\":\n",
    "    cache_name = f\"blocks.{layer}.hook_attn_out\"\n",
    "    neurons = model.cfg.d_model\n",
    "else:\n",
    "    raise NotImplementedError(f\"Unknown setting: {setting!r}\")\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MCS\n",
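    "Max cosine similarity (MCS) between the features of one dictionary and those of another. Features are matched one-to-one (Hungarian algorithm); if both dictionaries learned the same feature, the matched pair will have high cosine similarity.\n",
    "\n",
    "If two dictionaries both learned a feature, it is probably a real feature."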
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('# of features above 0.9:', 383)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAdFUlEQVR4nO3df3DX9X3A8VdCTBDkGwSbBGYoarcqU+sVJnxru642M9PU6YlXe+VY2mN6tcGb5M4K1cKqXeFoT50eyuZacVctO3fVVUQsw4nXI/5olDuGymbFwo5+g54jQSwJkM/+2PjeItT6DSR5Jzwed9878/l8vp+8vvcWvk8++X7zLcuyLAsAgISUD/UAAADvJ1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBITsVQD9Afvb29sWvXrhg3blyUlZUN9TgAwIeQZVns3bs3Jk+eHOXlH3yNZFgGyq5du6K+vn6oxwAA+mHnzp1x+umnf+AxwzJQxo0bFxH/+wBzudwQTwMAfBhdXV1RX19ffB7/IMMyUA7/WCeXywkUABhmPszLM7xIFgBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJJTMdQDcHxMXfjEgJz3zWVNA3JeAPggrqAAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJCcYwqUZcuWRVlZWdx4443Fbfv374+WlpaYOHFinHLKKTF79uzo6Ojoc78dO3ZEU1NTjBkzJmpqauKmm26KgwcPHssoAMAI0u9AefHFF+Pv/u7v4vzzz++zfcGCBfH444/HI488Ehs3boxdu3bFVVddVdx/6NChaGpqip6enti0aVM8+OCDsWrVqli8eHH/HwUAMKL0K1DefffdmDNnTtx///1x6qmnFrd3dnbGD37wg7jjjjvi4osvjunTp8cDDzwQmzZtiueeey4iIn72s5/FK6+8Ej/60Y/iggsuiEsvvTRuv/32WLFiRfT09ByfRwUADGv9CpSWlpZoamqKhoaGPtvb29vjwIEDfbafffbZMWXKlGhra4uIiLa2tjjvvPOitra2eExjY2N0dXXF1q1bj/r9uru7o6urq88NABi5Kkq9w+rVq+Oll16KF1988Yh9hUIhKisrY/z48X2219bWRqFQKB7z/+Pk8P7D+45m6dKl8e1vf7vUUQGAYaqkKyg7d+6Mv/qrv4qHHnooRo8ePVAzHWHRokXR2dlZvO3cuXPQvjcAMPhKCpT29vbYvXt3fPKTn4yKioqoqKiIjRs3xt133x0VFRVRW1sbPT09sWfPnj736+joiLq6uoiIqKurO+JdPYe/PnzM+1VVVUUul+tzAwBGrpIC5fOf/3xs2bIlNm/eXLzNmDEj5syZU/zvk046KTZs2FC8z7Zt22LHjh2Rz+cjIiKfz8eWLVti9+7dxWPWr18fuVwupk2bdpweFgAwnJX0GpRx48bFueee22fb2LFjY+LEicXt8+bNi9bW1pgwYULkcrm44YYbIp/Px6xZsyIi4pJLLolp06bF3LlzY/ny5VEoFOLWW2+NlpaWqKqqOk4PCwAYzkp+kezvcuedd0Z5eXnMnj07uru7o7GxMe69997i/lGjRsWaNWvi+uuvj3w+H2PHjo3m5ua47bbbjvcoAMAw
VZZlWTbUQ5Sqq6srqquro7Oz0+tR/s/UhU8MyHnfXNY0IOcF4MRTyvO3z+IBAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAklMx1AOQtqkLnxiwc7+5rGnAzg3A8OYKCgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkJySAuW+++6L888/P3K5XORyucjn8/Hkk08W9+/fvz9aWlpi4sSJccopp8Ts2bOjo6Ojzzl27NgRTU1NMWbMmKipqYmbbropDh48eHweDQAwIpQUKKeffnosW7Ys2tvb4xe/+EVcfPHFccUVV8TWrVsjImLBggXx+OOPxyOPPBIbN26MXbt2xVVXXVW8/6FDh6KpqSl6enpi06ZN8eCDD8aqVati8eLFx/dRAQDDWlmWZdmxnGDChAnxve99L66++ur4yEc+Eg8//HBcffXVERHx2muvxTnnnBNtbW0xa9asePLJJ+MLX/hC7Nq1K2prayMiYuXKlXHzzTfHW2+9FZWVlR/qe3Z1dUV1dXV0dnZGLpc7lvFHjKkLnxjqEUr25rKmoR4BgEFUyvN3v1+DcujQoVi9enXs27cv8vl8tLe3x4EDB6KhoaF4zNlnnx1TpkyJtra2iIhoa2uL8847rxgnERGNjY3R1dVVvApzNN3d3dHV1dXnBgCMXCUHypYtW+KUU06Jqqqq+NrXvhaPPvpoTJs2LQqFQlRWVsb48eP7HF9bWxuFQiEiIgqFQp84Obz/8L7fZunSpVFdXV281dfXlzo2ADCMlBwoH//4x2Pz5s3x/PPPx/XXXx/Nzc3xyiuvDMRsRYsWLYrOzs7ibefOnQP6/QCAoVVR6h0qKyvjYx/7WERETJ8+PV588cX427/927jmmmuip6cn9uzZ0+cqSkdHR9TV1UVERF1dXbzwwgt9znf4XT6HjzmaqqqqqKqqKnVUAGCYOubfg9Lb2xvd3d0xffr0OOmkk2LDhg3Ffdu2bYsdO3ZEPp+PiIh8Ph9btmyJ3bt3F49Zv3595HK5mDZt2rGOAgCMECVdQVm0aFFceumlMWXKlNi7d288/PDD8cwzz8RTTz0V1dXVMW/evGhtbY0JEyZELpeLG264IfL5fMyaNSsiIi655JKYNm1azJ07N5YvXx6FQiFuvfXWaGlpcYUEACgqKVB2794df/EXfxG//vWvo7q6Os4///x46qmn4k//9E8jIuLOO++M8vLymD17dnR3d0djY2Pce++9xfuPGjUq1qxZE9dff33k8/kYO3ZsNDc3x2233XZ8HxUAMKwd8+9BGQp+D8qR/B4UAFI3KL8HBQBgoAgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSUzHUAwAA/TN14RMDdu43lzUN2Lk/DFdQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkAB
AJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOSUFytKlS+OP/uiPYty4cVFTUxNXXnllbNu2rc8x+/fvj5aWlpg4cWKccsopMXv27Ojo6OhzzI4dO6KpqSnGjBkTNTU1cdNNN8XBgweP/dEAACNCSYGycePGaGlpieeeey7Wr18fBw4ciEsuuST27dtXPGbBggXx+OOPxyOPPBIbN26MXbt2xVVXXVXcf+jQoWhqaoqenp7YtGlTPPjgg7Fq1apYvHjx8XtUAMCwVpZlWdbfO7/11ltRU1MTGzdujD/+4z+Ozs7O+MhHPhIPP/xwXH311RER8dprr8U555wTbW1tMWvWrHjyySfjC1/4QuzatStqa2sjImLlypVx8803x1tvvRWVlZW/8/t2dXVFdXV1dHZ2Ri6X6+/4I8rUhU8M9Qgle3NZ01CPADCsDeTf/QPxd3Qpz9/H9BqUzs7OiIiYMGFCRES0t7fHgQMHoqGhoXjM2WefHVOmTIm2traIiGhra4vzzjuvGCcREY2NjdHV1RVbt2496vfp7u6Orq6uPjcAYOTqd6D09vbGjTfeGBdddFGce+65ERFRKBSisrIyxo8f3+fY2traKBQKxWP+f5wc3n9439EsXbo0qquri7f6+vr+jg0ADAP9DpSWlpb493//91i9evXxnOeoFi1aFJ2dncXbzp07B/x7AgBDp6I/d5o/f36sWbMmnn322Tj99NOL2+vq6qKnpyf27NnT5ypKR0dH1NXVFY954YUX+pzv8Lt8Dh/zflVVVVFVVdWfUQGAYaikKyhZlsX8+fPj0UcfjaeffjrOOOOMPvunT58eJ510UmzYsKG4bdu2bbFjx47I5/MREZHP52PLli2xe/fu4jHr16+PXC4X06ZNO5bHAgCMECVdQWlpaYmHH344/uVf/iXGjRtXfM1IdXV1nHzyyVFdXR3z5s2L1tbWmDBhQuRyubjhhhsin8/HrFmzIiLikksuiWnTpsXcuXNj+fLlUSgU4tZbb42WlhZXSQCAiCgxUO67776IiPiTP/mTPtsfeOCB+MpXvhIREXfeeWeUl5fH7Nmzo7u7OxobG+Pee+8tHjtq1KhYs2ZNXH/99ZHP52Ps2LHR3Nwct91227E9EgBgxCgpUD7Mr0wZPXp0rFixIlasWPFbj/noRz8aa9euLeVbAwAnEJ/FAwAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEByBAoAkByBAgAkp2KoBziRTF34xFCPAADDgisoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQHIECACRHoAAAyREoAEBySg6UZ599Ni6//PKYPHlylJWVxWOPPdZnf5ZlsXjx4pg0aVKcfPLJ0dDQEP/5n//Z55h33nkn5syZE7lcLsaPHx/z5s2Ld99995geCAAwcpQcKPv27YtPfOITsWLFiqPuX758edx9992xcuXKeP7552Ps2LHR2NgY+/fvLx4zZ86c2Lp1a6xfvz7WrFkTzz77bFx33XX9fxQAwIhSUeodLr300rj00kuPui/Lsrjrrrvi1ltvjSuuuCIiIv7xH/8xamtr47HHHosvfelL8eqrr8a6devixRdfjBkzZkRExD333BOXXXZZfP/734/Jkycfw8MBAEaC4/oalO3bt0ehUIiGhobiturq6pg5c2a0tbVFRERbW1uMHz++GCcREQ0NDVFeXh7PP//8
Uc/b3d0dXV1dfW4AwMh1XAOlUChERERtbW2f7bW1tcV9hUIhampq+uyvqKiICRMmFI95v6VLl0Z1dXXxVl9ffzzHBgASMyzexbNo0aLo7Ows3nbu3DnUIwEAA+i4BkpdXV1ERHR0dPTZ3tHRUdxXV1cXu3fv7rP/4MGD8c477xSPeb+qqqrI5XJ9bgDAyHVcA+WMM86Iurq62LBhQ3FbV1dXPP/885HP5yMiIp/Px549e6K9vb14zNNPPx29vb0xc+bM4zkOADBMlfwunnfffTdef/314tfbt2+PzZs3x4QJE2LKlClx4403xne+8534/d///TjjjDPiW9/6VkyePDmuvPLKiIg455xz4s/+7M/i2muvjZUrV8aBAwdi/vz58aUvfck7eAAYcaYufGKoRxiWSg6UX/ziF/G5z32u+HVra2tERDQ3N8eqVaviG9/4Ruzbty+uu+662LNnT3z605+OdevWxejRo4v3eeihh2L+/Pnx+c9/PsrLy2P27Nlx9913H4eHAwCMBGVZlmVDPUSpurq6orq6Ojo7O4fV61FUdF9vLmsa6hEABtxw/bt/IP6OLuX5e1i8iwcAOLGU/CMeOF4G6l8VrswADH+uoAAAyREoAEByBAoAkByBAgAkR6AAAMnxLh4AiOH7+0pGKldQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDk+DRjAIYNnzh84nAFBQBIjkABAJIjUACA5HgNCsAJaiBfz/HmsqYBOzcnBldQAIDkCBQAIDl+xMOI47I1wPDnCgoAkBxXUIATykBdYXN1DY4vV1AAgOS4ggIl8K9vgMEhUAASNxw/f2Y4zkxaBAokwDuPhj9PyHB8eQ0KAJAcgQIAJMePeGCE8+MjYDhyBQUASI5AAQCSI1AAgOR4DQrQb95aCwwUV1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDkCBQBIjkABAJIjUACA5AgUACA5AgUASI5AAQCSI1AAgOQIFAAgORVDPUCKfIQ8AAwtV1AAgOQIFAAgOQIFAEiOQAEAkiNQAIDkCBQAIDlDGigrVqyIqVOnxujRo2PmzJnxwgsvDOU4AEAihixQ/umf/ilaW1tjyZIl8dJLL8UnPvGJaGxsjN27dw/VSABAIoYsUO6444649tpr46tf/WpMmzYtVq5cGWPGjIkf/vCHQzUSAJCIIflNsj09PdHe3h6LFi0qbisvL4+GhoZoa2s74vju7u7o7u4uft3Z2RkREV1dXQMyX2/3ewNyXgAYLgbiOfbwObMs+53HDkmgvP3223Ho0KGora3ts722tjZee+21I45funRpfPvb3z5ie319/YDNCAAnsuq7Bu7ce/fujerq6g88Zlh8Fs+iRYuitbW1+HVvb2+88847MXHixCgrKxvCydLT1dUV9fX1sXPnzsjlckM9DmFNUmRN0mI90jNQa5JlWezduzcmT578O48dkkA57bTTYtSoUdHR0dFne0dHR9TV1R1xfFVVVVRVVfXZNn78+IEccdjL5XL+oCfGmqTHmqTFeqRnINbkd105OWxIXiRbWVkZ06dPjw0bNhS39fb2xoYNGyKfzw/FSABAQobsRzytra3R3NwcM2bMiAsvvDDuuuuu2LdvX3z1q18dqpEAgEQMWaBcc8018dZbb8XixYujUCjEBRdcEOvWrTvihbOUpqqqKpYsWXLEj8QYOtYkPdYkLdYjPSmsSVn2Yd7rAwAwiHwWDwCQHIECACRHoAAAyREoAEByBMowtGLFipg6dWqMHj06Zs6cGS+88MJvPfb++++Pz3zmM3HqqafGqaeeGg0NDR94PP1Typr8f6tXr46ysrK48sorB3bAE1Cpa7Jnz55oaWmJSZMmRVVVVfzBH/xBrF27dpCmHflKXY+77rorPv7xj8fJJ58c9fX1sWDBgti/f/8gTTuyPfvss3H55ZfH5MmT
o6ysLB577LHfeZ9nnnkmPvnJT0ZVVVV87GMfi1WrVg34nJExrKxevTqrrKzMfvjDH2Zbt27Nrr322mz8+PFZR0fHUY//8pe/nK1YsSJ7+eWXs1dffTX7yle+klVXV2f/9V//NciTj1ylrslh27dvz37v934v+8xnPpNdccUVgzPsCaLUNenu7s5mzJiRXXbZZdnPf/7zbPv27dkzzzyTbd68eZAnH5lKXY+HHnooq6qqyh566KFs+/bt2VNPPZVNmjQpW7BgwSBPPjKtXbs2u+WWW7Kf/OQnWURkjz766Ace/8Ybb2RjxozJWltbs1deeSW75557slGjRmXr1q0b0DkFyjBz4YUXZi0tLcWvDx06lE2ePDlbunTph7r/wYMHs3HjxmUPPvjgQI14wunPmhw8eDD71Kc+lf3DP/xD1tzcLFCOs1LX5L777svOPPPMrKenZ7BGPKGUuh4tLS3ZxRdf3Gdba2trdtFFFw3onCeiDxMo3/jGN7I//MM/7LPtmmuuyRobGwdwsizzI55hpKenJ9rb26OhoaG4rby8PBoaGqKtre1DneO9996LAwcOxIQJEwZqzBNKf9fktttui5qampg3b95gjHlC6c+a/PSnP418Ph8tLS1RW1sb5557bnz3u9+NQ4cODdbYI1Z/1uNTn/pUtLe3F38M9MYbb8TatWvjsssuG5SZ6autra3P+kVENDY2fujnnf4aFp9mzP96++2349ChQ0f8tt3a2tp47bXXPtQ5br755pg8efIR/7PRP/1Zk5///Ofxgx/8IDZv3jwIE554+rMmb7zxRjz99NMxZ86cWLt2bbz++uvx9a9/PQ4cOBBLliwZjLFHrP6sx5e//OV4++2349Of/nRkWRYHDx6Mr33ta/HNb35zMEbmfQqFwlHXr6urK37zm9/EySefPCDf1xWUE8iyZcti9erV8eijj8bo0aOHepwT0t69e2Pu3Llx//33x2mnnTbU4/B/ent7o6amJv7+7/8+pk+fHtdcc03ccsstsXLlyqEe7YT0zDPPxHe/+924995746WXXoqf/OQn8cQTT8Ttt98+1KMxiFxBGUZOO+20GDVqVHR0dPTZ3tHREXV1dR943+9///uxbNmy+Nd//dc4//zzB3LME0qpa/LLX/4y3nzzzbj88suL23p7eyMioqKiIrZt2xZnnXXWwA49wvXnz8mkSZPipJNOilGjRhW3nXPOOVEoFKKnpycqKysHdOaRrD/r8a1vfSvmzp0bf/mXfxkREeedd17s27cvrrvuurjllluivNy/rQdTXV3dUdcvl8sN2NWTCFdQhpXKysqYPn16bNiwobitt7c3NmzYEPl8/rfeb/ny5XH77bfHunXrYsaMGYMx6gmj1DU5++yzY8uWLbF58+bi7c///M/jc5/7XGzevDnq6+sHc/wRqT9/Ti666KJ4/fXXi7EYEfEf//EfMWnSJHFyjPqzHu+9994REXI4HjMfHzfo8vl8n/WLiFi/fv0HPu8cFwP6ElyOu9WrV2dVVVXZqlWrsldeeSW77rrrsvHjx2eFQiHLsiybO3dutnDhwuLxy5YtyyorK7N//ud/zn79618Xb3v37h2qhzDilLom7+ddPMdfqWuyY8eObNy4cdn8+fOzbdu2ZWvWrMlqamqy73znO0P1EEaUUtdjyZIl2bhx47If//jH2RtvvJH97Gc/y84666zsi1/84lA9hBFl79692csvv5y9/PLLWURkd9xxR/byyy9nv/rVr7Isy7KFCxdmc+fOLR5/+G3GN910U/bqq69mK1as8DZjju6ee+7JpkyZklVWVmYXXnhh9txzzxX3ffazn82am5uLX3/0ox/NIuKI25IlSwZ/8BGslDV5P4EyMEpdk02bNmUzZ87MqqqqsjPPPDP7m7/5m+zgwYODPPXIVcp6HDhwIPvrv/7r7KyzzspGjx6d1dfXZ1//+tez//7v/x78wUegf/u3fzvq88LhNWhubs4++9nPHnGf
Cy64IKusrMzOPPPM7IEHHhjwOcuyzPUyACAtXoMCACRHoAAAyREoAEByBAoAkByBAgAkR6AAAMkRKABAcgQKAJAcgQIAJEegAADJESgAQHIECgCQnP8BelmWlACjPiEAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from scipy.optimize import linear_sum_assignment\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Dictionary comparison: match each feature of the smaller dictionary to a distinct feature\n",
    "# of the larger one (Hungarian algorithm) so that the total cosine similarity is maximised.\n",
    "SIMILARITY_THRESHOLD = 0.9\n",
    "\n",
    "smaller_dict_features, _ = smaller_dict.shape\n",
    "larger_dict_features, _ = larger_dict.shape\n",
    "larger_dict = larger_dict.to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    # All pairwise cosine similarities at once: rows of both dictionaries are unit-normalised,\n",
    "    # so one matmul gives a (smaller_dict_features, larger_dict_features) similarity matrix.\n",
    "    smaller_unit = torch.nn.functional.normalize(smaller_dict.to(device), dim=1)\n",
    "    larger_unit = torch.nn.functional.normalize(larger_dict, dim=1)\n",
    "    cos_sim_matrix = (smaller_unit @ larger_unit.T).double().cpu().numpy()\n",
    "\n",
    "# linear_sum_assignment minimises total cost, so use 1 - similarity as the cost.\n",
    "# (`cos_sims` keeps its previous meaning: the cost matrix.)\n",
    "cos_sims = 1 - cos_sim_matrix\n",
    "row_ind, col_ind = linear_sum_assignment(cos_sims)\n",
    "max_cosine_similarities = cos_sim_matrix[row_ind, col_ind]\n",
    "\n",
    "# Indices of the matched pairs, most similar first.\n",
    "max_indices = np.argsort(max_cosine_similarities)[::-1]\n",
    "print(f\"# of features above {SIMILARITY_THRESHOLD}: {(max_cosine_similarities > SIMILARITY_THRESHOLD).sum()}\")\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(max_cosine_similarities, bins=20)\n",
    "ax.set(title=\"Max cosine similarity of matched features\", xlabel=\"cosine similarity\", ylabel=\"# of features\")\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model activations & Dictionary Activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset parquet (/home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n",
      "Loading cached processed dataset at /home/mchorse/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-7d5c4ad56a49701a.arrow\n",
      "                                                                     \r"
     ]
    }
   ],
   "source": [
    "# Load the dataset (served from the local Hugging Face cache when available), tokenize it,\n",
    "# keep only documents longer than `token_amount` tokens and truncate them to exactly that length.\n",
    "from datasets import Dataset, load_dataset\n",
    "dataset_name = \"NeelNanda/pile-10k\"\n",
    "token_amount = 40\n",
    "dataset = load_dataset(dataset_name, split=\"train\").map(\n",
    "    lambda batch: model.tokenizer(batch[\"text\"]),\n",
    "    batched=True,\n",
    ").filter(\n",
    "    # batched=True: the predicate receives a dict of columns and returns one bool per row\n",
    "    lambda batch: [len(ids) > token_amount for ids in batch[\"input_ids\"]],\n",
    "    batched=True,\n",
    ").map(\n",
    "    lambda batch: {\"input_ids\": [ids[:token_amount] for ids in batch[\"input_ids\"]]},\n",
    "    batched=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 309/309 [00:13<00:00, 22.72it/s]\n"
     ]
    }
   ],
   "source": [
    "# Run the model over the dataset and record, for every token, the activation at `cache_name`\n",
    "# and its encoding by the smaller dictionary (rows are ordered datapoint-major, token-minor).\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm.auto import tqdm\n",
    "from einops import rearrange\n",
    "# `neurons` is set in the setup cell from `setting` (d_mlp for \"mlp\", d_model otherwise);\n",
    "# overwriting it with d_model here broke the \"mlp\" setting.\n",
    "datapoints = dataset.num_rows\n",
    "batch_size = 32\n",
    "neuron_activations = torch.zeros((datapoints*token_amount, neurons))\n",
    "dictionary_activations = torch.zeros((datapoints*token_amount, smaller_dict_features))\n",
    "smaller_auto_encoder = autoencoder\n",
    "smaller_auto_encoder.to_device(device)\n",
    "\n",
    "with torch.no_grad(), dataset.formatted_as(\"pt\"):\n",
    "    dl = DataLoader(dataset[\"input_ids\"], batch_size=batch_size)\n",
    "    row = 0\n",
    "    for batch in tqdm(dl):\n",
    "        # Cache only the hook we need instead of every activation in the model.\n",
    "        _, cache = model.run_with_cache(batch.to(device), names_filter=cache_name)\n",
    "        batched_neuron_activations = rearrange(cache[cache_name], \"b s n -> (b s) n\")\n",
    "        n_rows = batched_neuron_activations.shape[0]\n",
    "        neuron_activations[row:row + n_rows] = batched_neuron_activations.cpu()\n",
    "        batched_dictionary_activations = smaller_auto_encoder.encode(batched_neuron_activations)\n",
    "        dictionary_activations[row:row + n_rows] = batched_dictionary_activations.cpu()\n",
    "        row += n_rows"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Feature Activation Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Useful Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "from circuitsvis.activations import text_neuron_activations\n",
    "# Get the activations for the best dict features\n",
    "def get_feature_datapoints(feature_index, dictionary_activations, dataset, k=10, setting=\"max\"):\n",
    "    \"\"\"Pick up to k token positions for one dictionary feature and decode their context.\n",
    "\n",
    "    setting:\n",
    "        'max'     -- the k positions with the highest activation.\n",
    "        'uniform' -- split [min, max] of the activations into k equal-width bins and draw\n",
    "                     one random position from each non-empty bin, highest bin first.\n",
    "        other     -- up to k random positions where the feature is non-zero.\n",
    "\n",
    "    Rows of `dictionary_activations` are treated as a flattened (datapoint, token) grid of\n",
    "    shape (datapoints, token_amount), using those notebook globals.\n",
    "\n",
    "    Returns (text_list, full_text, token_list, full_token_list): the decoded prefix up to and\n",
    "    including the selected token, the decoded full datapoint, and their token ids.\n",
    "    \"\"\"\n",
    "    feature_activations = dictionary_activations[:, feature_index]\n",
    "    if setting == \"max\":\n",
    "        found_indices = torch.argsort(feature_activations, descending=True)[:k]\n",
    "    elif setting == \"uniform\":\n",
    "        bin_boundaries = torch.linspace(torch.min(feature_activations), torch.max(feature_activations), k + 1)\n",
    "        # Bucketize against the k - 1 interior boundaries so values land in exactly k buckets\n",
    "        # (0..k-1). Using all k + 1 boundaries produced k + 1 buckets (the minimum value alone\n",
    "        # in bucket 0), i.e. one sample more than requested.\n",
    "        bins = torch.bucketize(feature_activations, bin_boundaries[1:-1])\n",
    "        sampled_indices = []\n",
    "        for bin_idx in torch.unique(bins):\n",
    "            bin_indices = torch.nonzero(bins == bin_idx, as_tuple=False).squeeze(dim=1)\n",
    "            sampled_indices.extend(np.random.choice(bin_indices, size=1, replace=False))\n",
    "        # torch.unique returns buckets in ascending order; flip so the highest bucket comes first.\n",
    "        found_indices = torch.tensor(sampled_indices).long().flip(dims=[0])\n",
    "    else:  # random\n",
    "        nonzero_indices = torch.nonzero(feature_activations)[:, 0]\n",
    "        shuffled_indices = nonzero_indices[torch.randperm(nonzero_indices.shape[0])]\n",
    "        found_indices = shuffled_indices[:k]\n",
    "\n",
    "    text_list, full_text, token_list, full_token_list = [], [], [], []\n",
    "    for flat_index in found_indices:\n",
    "        md, s_ind = (int(x) for x in np.unravel_index(int(flat_index), (datapoints, token_amount)))\n",
    "        input_ids = dataset[md][\"input_ids\"]\n",
    "        full_tok = torch.tensor(input_ids)\n",
    "        tok = input_ids[:s_ind + 1]\n",
    "        full_text.append(model.tokenizer.decode(full_tok))\n",
    "        text_list.append(model.tokenizer.decode(tok))\n",
    "        token_list.append(tok)\n",
    "        full_token_list.append(full_tok)\n",
    "    return text_list, full_text, token_list, full_token_list\n",
    "\n",
    "def get_neuron_activation(token, feature, model, setting=\"dictionary_basis\"):\n",
    "    \"\"\"Return the activation of `feature` at each position of `token` as a list of floats.\n",
    "\n",
    "    'dictionary_basis': the activation at the global `cache_name` hook is encoded with the\n",
    "    global `smaller_auto_encoder` and code entry `feature` is returned for every\n",
    "    (batch * seq) position. Any other setting returns raw component `feature` of the\n",
    "    hook activation for the first sequence in the batch only.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # Cache only the hook we read instead of every activation in the model.\n",
    "        _, cache = model.run_with_cache(token.to(model.cfg.device), names_filter=cache_name)\n",
    "        neuron_act_batch = cache[cache_name]\n",
    "        if setting == \"dictionary_basis\":\n",
    "            neuron_act_batch = rearrange(neuron_act_batch, \"b s n -> (b s) n\")\n",
    "            act = smaller_auto_encoder.encode(neuron_act_batch)\n",
    "            return act[:, feature].tolist()\n",
    "        # neuron/residual basis\n",
    "        return neuron_act_batch[0, :, feature].tolist()\n",
    "\n",
    "def ablate_text(text, feature, model, setting=\"dictionary_basis\"):\n",
    "    \"\"\"Show how much each token contributes to `feature`'s activation on the last position.\n",
    "\n",
    "    For every position i, token i is removed and the activation at the last position is\n",
    "    recomputed; the displayed value is (ablated - original). Single-token inputs are\n",
    "    skipped. `text` may be a string, a token tensor, or a list of either.\n",
    "    \"\"\"\n",
    "    if isinstance(text, str):\n",
    "        text = [text]\n",
    "    display_text_list = []\n",
    "    activation_list = []\n",
    "    for t in text:\n",
    "        if isinstance(t, str):  # raw text: tokenize it\n",
    "            split_text = model.to_str_tokens(t, prepend_bos=False)\n",
    "            tokens = model.to_tokens(t, prepend_bos=False)\n",
    "        else:  # already tokens\n",
    "            tokens = t\n",
    "            split_text = model.to_str_tokens(t, prepend_bos=False)\n",
    "        seq_size = tokens.shape[1]\n",
    "        if seq_size == 1:  # a single token cannot be ablated\n",
    "            continue\n",
    "        # The baseline must use the same basis (`setting`) as the ablated runs below;\n",
    "        # previously it was always computed in the dictionary basis.\n",
    "        original = get_neuron_activation(tokens, feature, model, setting)[-1]\n",
    "        changed_activations = torch.zeros(seq_size)\n",
    "        for i in range(seq_size):\n",
    "            # Remove the i'th token from the input\n",
    "            ablated_tokens = torch.cat((tokens[:, :i], tokens[:, i+1:]), dim=1)\n",
    "            changed_activations[i] = get_neuron_activation(ablated_tokens, feature, model, setting)[-1]\n",
    "        changed_activations -= original\n",
    "        display_text_list += [x.replace('\\n', '\\\\newline') for x in split_text] + [\"\\n\"]\n",
    "        activation_list += changed_activations.tolist() + [0.0]\n",
    "    activation_list = torch.tensor(activation_list).reshape(-1, 1, 1)\n",
    "    return text_neuron_activations(tokens=display_text_list, activations=activation_list)\n",
    "\n",
    "def visualize_text(text, feature, model, setting=\"dictionary_basis\", max_activation = None):\n",
    "    \"\"\"Render per-token activations of one or more features over one or more texts.\n",
    "\n",
    "    Each (text, feature) pair becomes one display line, terminated by a newline token with\n",
    "    activation 0. When `max_activation` is given, activations are clipped from above.\n",
    "    \"\"\"\n",
    "    texts = [text] if isinstance(text, str) else text\n",
    "    features = [feature] if isinstance(feature, int) else feature\n",
    "    display_text_list = []\n",
    "    act_list = []\n",
    "    for t in texts:\n",
    "        # Strings are tokenized here; anything else is assumed to already be tokens.\n",
    "        token = model.to_tokens(t, prepend_bos=False) if isinstance(t, str) else t\n",
    "        split_text = model.to_str_tokens(t, prepend_bos=False)\n",
    "        display_tokens = [x.replace('\\n', '\\\\newline') for x in split_text] + [\"\\n\"]\n",
    "        for f in features:\n",
    "            display_text_list.extend(display_tokens)\n",
    "            act_list.extend(get_neuron_activation(token, f, model, setting) + [0.0])\n",
    "    act_list = torch.tensor(act_list).reshape(-1, 1, 1)\n",
    "    if max_activation is not None:\n",
    "        act_list = torch.clamp(act_list, max=max_activation)\n",
    "    return text_neuron_activations(tokens=display_text_list, activations=act_list)\n",
    "# Ablate the feature direction(s) of the tokens\n",
    "# tokens: token tensor, presumably of shape (batch_size, seq_len)\n",
    "from einops import rearrange\n",
    "def ablate_feature_direction(tokens, feature, model, autoencoder):\n",
    "    \"\"\"Run `model` on `tokens` with the dictionary component of `feature` removed at `cache_name`.\n",
    "\n",
    "    At the hook, activations are encoded with `autoencoder` and, for every selected feature f,\n",
    "    act[:, f] * learned_dict[f] is subtracted. `feature` may be an int, a list of ints or a\n",
    "    1-D index tensor. Returns whatever `model.run_with_hooks` returns (logits by default).\n",
    "    \"\"\"\n",
    "    feature_idx = torch.as_tensor(feature, dtype=torch.long).reshape(-1)\n",
    "\n",
    "    def ablation_hook(value, hook):\n",
    "        batch, seq_len, _ = value.shape\n",
    "        flat_value = rearrange(value, 'b s h -> (b s) h')\n",
    "        act = autoencoder.encode(flat_value)\n",
    "        learned_dict = autoencoder.get_learned_dict()\n",
    "        idx = feature_idx.to(act.device)\n",
    "        # ((b s), n_feat) @ (n_feat, h) -> ((b s), h). Unlike torch.outer on squeezed vectors,\n",
    "        # this works for several features and for a single position.\n",
    "        feature_direction = act[:, idx] @ learned_dict[idx]\n",
    "        feature_direction = rearrange(feature_direction, '(b s) h -> b s h', b=batch, s=seq_len)\n",
    "        return value - feature_direction\n",
    "\n",
    "    return model.run_with_hooks(tokens, fwd_hooks=[(cache_name, ablation_hook)])\n",
    "def add_feature_direction(tokens, feature, model, autoencoder, scalar=1.0):\n",
    "    \"\"\"Run `model` on `tokens` with scalar * (direction of `feature`) added at `cache_name`.\n",
    "\n",
    "    The direction is column `feature` of `autoencoder.decoder.weight` when the autoencoder\n",
    "    has a `decoder`; otherwise (e.g. dictionaries without a separate decoder module) row\n",
    "    `feature` of `autoencoder.get_learned_dict()` is used, as in ablate_feature_direction.\n",
    "    \"\"\"\n",
    "    if hasattr(autoencoder, \"decoder\"):\n",
    "        feature_direction = autoencoder.decoder.weight[:, feature].squeeze()\n",
    "    else:\n",
    "        feature_direction = autoencoder.get_learned_dict()[feature].squeeze()\n",
    "\n",
    "    def residual_add_hook(value, hook):\n",
    "        return value + scalar * feature_direction\n",
    "\n",
    "    return model.run_with_hooks(tokens, fwd_hooks=[(cache_name, residual_add_hook)])\n",
    "def ablate_feature_direction_display(text, features=None, setting=\"true_tokens\", verbose=False):\n",
    "    \"\"\"Visualise how ablating dictionary feature(s) changes next-token log-probs.\n",
    "\n",
    "    For each text, log-probs are computed with and without the feature direction(s) removed\n",
    "    (ablate_feature_direction with the global `smaller_auto_encoder`), and\n",
    "    diff = ablated - original is displayed per predicted position:\n",
    "      'true_tokens': diff for the actual next token, shown under that token.\n",
    "      'max': at each position, the token with the most negative diff (the one the ablation\n",
    "             hurt most), shown with that diff.\n",
    "    If `features` is None the global `best_feature` is used. With verbose=True the original and\n",
    "    ablated log-probs are added as extra rows and the max/min displayed value is printed.\n",
    "    \"\"\"\n",
    "    if features is None:\n",
    "        features = torch.tensor([best_feature])\n",
    "    if isinstance(features, int):\n",
    "        features = torch.tensor([features])\n",
    "    if isinstance(features, list):\n",
    "        features = torch.tensor(features)\n",
    "    if setting not in (\"true_tokens\", \"max\"):\n",
    "        raise ValueError(f\"Unknown setting {setting!r}; expected 'true_tokens' or 'max'\")\n",
    "    if isinstance(text, str):\n",
    "        text = [text]\n",
    "    text_list = []\n",
    "    logit_list = []\n",
    "    for t in text:\n",
    "        tokens = model.to_tokens(t, prepend_bos=False)\n",
    "        with torch.no_grad():\n",
    "            original_logits = model(tokens).log_softmax(-1).cpu()\n",
    "            ablated_logits = ablate_feature_direction(tokens, features, model, smaller_auto_encoder).log_softmax(-1).cpu()\n",
    "        diff_logits = ablated_logits - original_logits  # ablated < original -> negative diff\n",
    "        tokens = tokens.cpu()\n",
    "        if setting == \"true_tokens\":\n",
    "            # Position p predicts token p+1: pair logits[:, :-1] with tokens[:, 1:] and drop the\n",
    "            # first token from the display since it is never predicted.\n",
    "            split_text = model.to_str_tokens(t, prepend_bos=False)[1:]\n",
    "            gather_tokens = rearrange(tokens[:, 1:], \"b s -> b s 1\")\n",
    "            diff = rearrange(diff_logits[:, :-1].gather(-1, gather_tokens), \"b s n -> (b s n)\")\n",
    "        else:  # \"max\"\n",
    "            # Token with the most negative diff at each predicted position.\n",
    "            val, ind = (-1 * diff_logits).max(-1)\n",
    "            val, ind = val[:, :-1], ind[:, :-1]\n",
    "            diff = -1 * rearrange(val, \"b s -> (b s)\")\n",
    "            # Show each position with its own arg-max token (aligned with `diff`); previously\n",
    "            # the tokens and gather indices were shifted by one position.\n",
    "            split_text = model.to_str_tokens(ind, prepend_bos=False)\n",
    "            gather_tokens = rearrange(ind, \"1 s -> 1 s 1\")\n",
    "        display_tokens = [x.replace('\\n', '\\\\newline') for x in split_text] + [\"\\n\"]\n",
    "        if verbose:\n",
    "            orig = rearrange(original_logits[:, :-1].gather(-1, gather_tokens), \"b s n -> (b s n)\")\n",
    "            ablated = rearrange(ablated_logits[:, :-1].gather(-1, gather_tokens), \"b s n -> (b s n)\")\n",
    "            text_list += display_tokens + display_tokens\n",
    "            logit_list += orig.tolist() + [0.0]\n",
    "            logit_list += ablated.tolist() + [0.0]\n",
    "        text_list += display_tokens\n",
    "        logit_list += diff.tolist() + [0.0]\n",
    "    logit_list = torch.tensor(logit_list).reshape(-1, 1, 1)\n",
    "    if verbose:\n",
    "        print(f\"Max & Min logit-diff: {logit_list.max().item():.2f} & {logit_list.min().item():.2f}\")\n",
    "    return text_neuron_activations(tokens=text_list, activations=logit_list)\n",
    "def generate_text(input_text, num_tokens, model, autoencoder, feature, temperature=0.7, setting=\"add\", scalar=1.0):\n",
    "    # Convert input text to tokens\n",
    "    input_ids = model.tokenizer.encode(input_text, return_tensors='pt').to(device)\n",
    "\n",
    "    for _ in range(num_tokens):\n",
    "        # Generate logits\n",
    "        with torch.no_grad():\n",
    "            if(setting==\"add\"):\n",
    "                logits = add_feature_direction(input_ids, feature, model, autoencoder, scalar=scalar)\n",
    "            else:\n",
    "                logits = model(input_ids)\n",
    "\n",
    "        # Apply temperature\n",
    "        logits = logits / temperature\n",
    "\n",
    "        # Sample from the distribution\n",
    "        probs = torch.nn.functional.softmax(logits[:, -1, :], dim=-1)\n",
    "        predicted_token = torch.multinomial(probs, num_samples=1)\n",
    "\n",
    "        # Append predicted token to input_ids\n",
    "        input_ids = torch.cat((input_ids, predicted_token), dim=-1)\n",
    "\n",
    "    # Decode the tokens to text\n",
    "    output_text = model.tokenizer.decode(input_ids[0])\n",
    "\n",
    "    return output_text\n",
    "\n",
    "# Logit Lens\n",
    "def logit_lens(model, best_feature, smaller_dict, layer):\n",
    "    with torch.no_grad():\n",
    "        # There are never-used tokens, which have high norm. We want to ignore these.\n",
    "        bad_ind = (model.W_U.norm(dim=0) > 20)\n",
    "        feature_direction = smaller_dict[best_feature].to(device)\n",
    "        # feature_direction = torch.matmul(feature_direction, model.W_out[layer]) # if MLP\n",
    "        logits = torch.matmul(feature_direction, model.W_U).cpu()\n",
    "    # Don't include bad indices\n",
    "    logits[bad_ind] = -1000\n",
    "    topk_values, topk_indices = torch.topk(logits, 20)\n",
    "    top_text = model.to_str_tokens(topk_indices)\n",
    "    print(f\"{top_text}\")\n",
    "    print(topk_values)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Feature Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = \" I do like a\"\n",
    "split_text = model.to_str_tokens(t, prepend_bos=False)\n",
    "token = model.to_tokens(t, prepend_bos=False)\n",
    "_, cache = model.run_with_cache(token.to(model.cfg.device))\n",
    "neuron_act_batch = cache[cache_name]\n",
    "_, act = smaller_auto_encoder(neuron_act_batch)\n",
    "v, i = act[0, -1, :].topk(10)\n",
    "\n",
    "print(\"Activations:\",[round(val,2) for val in v.tolist()])\n",
    "print(\"Feature_ids\", i.tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Feature Interp\n",
    "Investigate the example sentences the activate this feature.\n",
    "\n",
    "Max: show max activating (tokens,contexts)\n",
    "\n",
    "Uniform: Show range of activations from each bin (e.g. sample an example from 1-2, 2-3, etc). \n",
    "[Note: if a feature is monosemantic, then the full range of activations should be that feature, not just max-activating ones]\n",
    "\n",
    "Full_text: shows the full text example\n",
    "\n",
    "Text_list: shows up to the most activating example (try w/ max activating on a couple of examples to see)\n",
    "\n",
    "ablate_text: remove the context one token at a time, and show the decrease/increase in activation of that feature\n",
    "\n",
    "ablate_feature_direction: removes feature direction from model's activation mid-inference, showing the logit diff in the output for every token.\n",
    "\n",
    "logit_lens: show the logit lens for that feature. If matches ablate_feature_direction, then the computation path is through the residual stream, else, it's through future layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Feature index: 11\n",
      "MCS: 0.7368124723434448\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div id=\"circuits-vis-67ff7915-69ff\" style=\"margin: 15px 0;\"/>\n",
       "    <script crossorigin type=\"module\">\n",
       "    import { render, TextNeuronActivations } from \"https://unpkg.com/circuitsvis@1.40.0/dist/cdn/esm.js\";\n",
       "    render(\n",
       "      \"circuits-vis-67ff7915-69ff\",\n",
       "      TextNeuronActivations,\n",
       "      {\"tokens\": [\"Description\", \":\", \" This\", \" project\", \" proposes\", \" two\", \" outreach\", \" programs\", \":\", \" 1\", \")\", \" outreach\", \" to\", \" the\", \" \\\"\", \"Community\", \"\\\"\", \" through\", \" an\", \" effort\", \" with\", \" people\", \" in\", \" Mexico\", \" and\", \" 2\", \")\", \" an\", \" outreach\", \" to\", \" \\\"\", \"\\n\", \"After\", \" acknowledging\", \" that\", \" Santa\", \" Barbara\", \"'s\", \" high\", \"-\", \"density\", \" housing\", \" program\", \" has\", \" failed\", \" in\", \" many\", \" areas\", \",\", \" the\", \" city\", \" is\", \" turning\", \" from\", \" Mil\", \"pas\", \" Street\", \" to\", \" downtown\", \" as\", \" a\", \" place\", \" to\", \" build\", \"\\n\", \"This\", \" is\", \" a\", \" limitation\", \" imposed\", \" by\", \" jet\", \"ty\", \" server\", \" where\", \" the\", \" HTTP\", \" response\", \" header\", \" is\", \" hard\", \" coded\", \" to\", \" 64\", \"kb\", \" in\", \" PH\", \"D\", \" 1\", \".\", \"0\", \".\", \"1\", \" and\", \" 64\", \"kb\", \" is\", \" the\", \"\\n\", \"Republic\", \"ans\", \" are\", \" getting\", \" to\", \" the\", \" rock\", \"-\", \"bottom\", \" of\", \" it\", \".\", \"\\\\newline\", \"\\\\newline\", \"They\", \" have\", \" defended\", \"\\n\", \"I\", \" got\", \" a\", \" wake\", \" up\", \" call\", \",\", \" I\", \" got\", \" to\", \" make\", \" this\", \" work\", \"Cause\", \" if\", \" we\", \" don\", \"\\u00b4\", \"t\", \" we\", \"\\u00b4\", \"re\", \" left\", \" with\", \" nothing\", \" and\", \" that\", \"\\u00b4\", \"\\n\", \"Q\", \":\", \"\\\\newline\", \"\\\\newline\", \"HTTP\", \" Response\", \" in\", \" Android\", \" -\", \" Network\", \"On\", \"Main\", \"Thread\", \"Exception\", \"\\\\newline\", \"\\\\newline\", \"I\", \" want\", \" to\", \" check\", \" the\", \" HTTP\", \" response\", \" of\", \" a\", \" certain\", \" URL\", \" before\", \" loading\", \" into\", \" a\", \" web\", \"view\", \".\", \" I\", \" only\", \" want\", \" to\", 
\" load\", \" web\", \"\\n\", \"Map\", \" Of\", \" London\", \" 32\", \" Borough\", \"s\", \" Ne\", \"ighborhood\", \"s\", \" In\", \"\\\\newline\", \"\\\\newline\", \"Map\", \" Of\", \" London\", \" 32\", \" Borough\", \"s\", \" Ne\", \"ighborhood\", \"s\", \" In\", \"\\\\newline\", \"\\\\newline\", \"Map\", \" Of\", \" London\", \" 32\", \" Borough\", \"\\n\", \"<\", \"header\", \" class\", \"=\\\"\", \"header\", \"-\", \"wrapper\", \"\\\">\", \"\\\\newline\\\\newline\", \"  \", \"<\", \"nav\", \" class\", \"=\\\"\", \"inner\", \"\\\">\", \"\\\\newline\", \"    \", \"<\", \"div\", \" class\", \"=\\\"\", \"title\", \"\\\">\", \"\\\\newline\", \"      \", \"<\", \"a\", \" href\", \"=\\\"/\", \"\\\">\", \"\\\\newline\", \"        \", \"<\", \"img\", \" class\", \"=\\\"\", \"logo\", \"\\n\", \"Poly\", \"(\", \"alkyl\", \"cy\", \"ano\", \"acrylate\", \")\", \" nanoc\", \"aps\", \"ules\", \":\", \" physic\", \"ochemical\", \" characterization\", \" and\", \" mechanism\", \" of\", \" formation\", \".\", \"\\\\newline\", \"Nan\", \"oc\", \"aps\", \"ules\", \" of\", \" poly\", \"(\", \"is\", \"ob\", \"ut\", \"yl\", \"cy\", \"\\n\", \"See\", \",\", \" How\", \" Effective\", \" is\", \" Milk\", \" For\", \" Cold\", \" So\", \"res\", \" Treatment\", \"\\u2026.\", \"\\\\newline\", \"\\\\newline\", \"You\", \" are\", \" definitely\", \" going\", \" to\", \" use\", \" milk\", \" for\", \" cold\", \" so\", \"res\", \" after\", \" learning\", \" the\", \" awesome\", \" benefits\", \" of\", \" milk\", \" for\", \" cold\", \" so\", \"\\n\", \"Q\", \":\", \"\\\\newline\", \"\\\\newline\", \"How\", \" to\", \" say\", \" \\\"\", \"I\", \" look\", \"\\\"\", \" as\", \" in\", \" \\\"\", \"I\", \" look\", \" stupid\", \"\\\"\", \" in\", \" French\", \"?\", \"\\\\newline\", \"\\\\newline\", \"How\", \" do\", \" you\", \" say\", \" \\\"\", \"I\", \" look\", \"\\\"\", \" as\", \" in\", \" \\\"\", \"I\", \"\\n\"], \"activations\": [[[-0.02271796017885208]], [[0.00012589991092681885]], 
[[-0.054130859673023224]], [[-0.08781802654266357]], [[0.17008452117443085]], [[0.2159292846918106]], [[0.14494235813617706]], [[0.19824208319187164]], [[0.01864047348499298]], [[-0.1748228669166565]], [[-0.22724087536334991]], [[-0.46498429775238037]], [[0.1333581954240799]], [[0.3213043808937073]], [[0.21151067316532135]], [[0.10226581990718842]], [[0.2084466964006424]], [[0.20415140688419342]], [[0.16048340499401093]], [[0.10236239433288574]], [[0.3660184144973755]], [[0.3913286328315735]], [[0.306435763835907]], [[0.22170202434062958]], [[-0.027727141976356506]], [[-0.4888033866882324]], [[-0.04541388899087906]], [[0.09957514703273773]], [[-0.017748981714248657]], [[0.40519285202026367]], [[0.8116349577903748]], [[0.0]], [[0.02264060080051422]], [[-0.1141149252653122]], [[-0.19982792437076569]], [[-0.2240270972251892]], [[-0.29898205399513245]], [[-0.20517103374004364]], [[-0.09148988872766495]], [[-0.02149951457977295]], [[-0.15469488501548767]], [[-0.18705129623413086]], [[-0.22260244190692902]], [[-0.1511133462190628]], [[-0.08827223628759384]], [[-0.01819281280040741]], [[0.03795833885669708]], [[0.05099277198314667]], [[-0.07216031849384308]], [[0.24821840226650238]], [[0.3586735129356384]], [[0.3206532597541809]], [[0.43407344818115234]], [[0.3122594356536865]], [[0.06416518986225128]], [[0.2361217588186264]], [[0.23877419531345367]], [[0.20951290428638458]], [[0.17694945633411407]], [[0.36777275800704956]], [[0.3675166368484497]], [[0.43706411123275757]], [[0.551352322101593]], [[0.5651237368583679]], [[0.0]], [[-0.08812133967876434]], [[-0.19097334146499634]], [[-0.24658958613872528]], [[-0.15666425228118896]], [[0.03504219651222229]], [[0.14586125314235687]], [[0.15120451152324677]], [[0.16606469452381134]], [[0.21782274544239044]], [[0.4346509575843811]], [[0.46052658557891846]], [[0.2850242853164673]], [[0.288537859916687]], [[0.42769157886505127]], [[0.35478824377059937]], [[0.20123787224292755]], [[0.2989463806152344]], [[0.33083271980285645]], 
[[0.185587540268898]], [[0.2549557685852051]], [[0.27951139211654663]], [[0.12208981812000275]], [[0.18646495044231415]], [[0.1542532593011856]], [[0.041822731494903564]], [[0.16138391196727753]], [[0.0888642966747284]], [[0.17317746579647064]], [[0.3594822883605957]], [[0.09364363551139832]], [[0.2686281204223633]], [[0.2604776620864868]], [[0.3635105490684509]], [[0.0]], [[0.08126188814640045]], [[0.060776785016059875]], [[0.07972384989261627]], [[0.09449535608291626]], [[0.008536145091056824]], [[-0.019662566483020782]], [[0.0605793297290802]], [[0.06206357479095459]], [[0.09833656251430511]], [[0.05748322606086731]], [[0.06098593771457672]], [[0.030033409595489502]], [[0.04356478154659271]], [[0.09424091875553131]], [[0.13979803025722504]], [[0.18847741186618805]], [[0.25771790742874146]], [[0.0]], [[-0.05830682814121246]], [[-0.17235779762268066]], [[-0.24779918789863586]], [[-0.2744314670562744]], [[-0.281788170337677]], [[-0.30417677760124207]], [[-0.3063863217830658]], [[-0.2750124931335449]], [[-0.39270585775375366]], [[-0.4331475496292114]], [[-0.32983726263046265]], [[-0.4339643716812134]], [[-0.47068560123443604]], [[-0.1896301954984665]], [[-0.15768881142139435]], [[-0.07852374762296677]], [[-0.08830097317695618]], [[-0.14143548905849457]], [[-0.10080580413341522]], [[-0.0015527158975601196]], [[0.019663840532302856]], [[0.1377822309732437]], [[0.20735980570316315]], [[0.19872520864009857]], [[0.4090338349342346]], [[0.05514697730541229]], [[0.19672171771526337]], [[0.036218926310539246]], [[0.0]], [[-0.07803662121295929]], [[-0.1546819806098938]], [[-0.06160996854305267]], [[-0.14303350448608398]], [[-0.08734390884637833]], [[-0.1002034917473793]], [[-0.06285814195871353]], [[-0.03207971155643463]], [[-0.10069897770881653]], [[-0.04699814319610596]], [[0.0004656761884689331]], [[-0.0393843874335289]], [[-0.07664371281862259]], [[-0.11653220653533936]], [[-0.04444471001625061]], [[-0.13843977451324463]], [[-0.14171122014522552]], 
[[0.016744345426559448]], [[-0.03416643291711807]], [[-0.1589541733264923]], [[-0.12328694015741348]], [[-0.012346580624580383]], [[-0.30981823801994324]], [[-0.1721029132604599]], [[-0.06545811891555786]], [[-0.10482733696699142]], [[-0.1604556143283844]], [[-0.05993349850177765]], [[0.028431400656700134]], [[-0.06751053035259247]], [[-0.0493038147687912]], [[-0.12129632383584976]], [[-0.11080619692802429]], [[-0.13755576312541962]], [[-0.18909139931201935]], [[-0.09998013824224472]], [[0.027714908123016357]], [[-0.06737224757671356]], [[-0.055356465280056]], [[-0.2659788429737091]], [[0.0]], [[-0.02658993750810623]], [[0.008966311812400818]], [[0.02251298725605011]], [[0.02871297299861908]], [[0.09627316892147064]], [[0.12784089148044586]], [[0.14492593705654144]], [[0.19973234832286835]], [[0.20565585792064667]], [[0.14520283043384552]], [[0.030672520399093628]], [[0.03722512722015381]], [[0.01632741093635559]], [[-0.08986619114875793]], [[-0.26859599351882935]], [[0.26751208305358887]], [[-0.6227580308914185]], [[0.27438509464263916]], [[0.16129513084888458]], [[-0.2668743133544922]], [[0.3629252314567566]], [[-0.02175326645374298]], [[-0.04671717435121536]], [[0.1555154174566269]], [[-0.17212781310081482]], [[-0.15234170854091644]], [[-0.244182288646698]], [[0.30335891246795654]], [[-0.5666812062263489]], [[0.0]], [[-0.04483465105295181]], [[-0.04473121464252472]], [[-0.1324356347322464]], [[-0.3755863308906555]], [[-0.7354291081428528]], [[-0.579806387424469]], [[-0.5413438677787781]], [[-0.25827324390411377]], [[-0.030796609818935394]], [[-0.1376980096101761]], [[-0.30647438764572144]], [[-0.359840989112854]], [[-0.24390657246112823]], [[-0.3858455717563629]], [[-0.5484462976455688]], [[-0.13769188523292542]], [[-0.1911497861146927]], [[-0.07870756089687347]], [[-0.29733553528785706]], [[-0.418548047542572]], [[-0.3509906232357025]], [[-0.5097808837890625]], [[-0.6101454496383667]], [[-0.18502973020076752]], [[-0.19842863082885742]], 
[[-0.13593026995658875]], [[-0.3077087998390198]], [[-0.36224183440208435]], [[-0.35007572174072266]], [[-0.44160324335098267]], [[-0.25998806953430176]], [[-0.27355241775512695]], [[-0.23570024967193604]], [[-0.3526879549026489]], [[-0.522918701171875]], [[-0.4911043643951416]], [[-0.6570993661880493]], [[-0.7069520354270935]], [[0.0]], [[-0.07919597625732422]], [[-0.12195703387260437]], [[-0.10125485062599182]], [[-0.11456999182701111]], [[-0.11899442970752716]], [[-0.1566152572631836]], [[-0.0631595104932785]], [[-0.06126046180725098]], [[0.0009112954139709473]], [[-0.023545272648334503]], [[0.04881805181503296]], [[0.08608727157115936]], [[0.18730901181697845]], [[0.23531033098697662]], [[0.27589917182922363]], [[0.32134705781936646]], [[0.2616477608680725]], [[0.3468858003616333]], [[0.02747330069541931]], [[0.20727433264255524]], [[0.14335615932941437]], [[-0.19637024402618408]], [[-0.32197707891464233]], [[0.11814276874065399]], [[0.12126688659191132]], [[-0.10470449924468994]], [[-0.36079341173171997]], [[-0.2696393132209778]], [[-0.22029678523540497]], [[-0.411615788936615]], [[-0.47793489694595337]], [[-1.0242382287979126]], [[0.0]], [[-0.08102145791053772]], [[-0.08471179008483887]], [[-0.15655893087387085]], [[0.020739972591400146]], [[0.0334654301404953]], [[0.01180654764175415]], [[0.15850277245044708]], [[0.2461603432893753]], [[0.15277160704135895]], [[0.2482084184885025]], [[0.27483320236206055]], [[0.08626210689544678]], [[0.0033192038536071777]], [[-0.005029588937759399]], [[-0.03099006414413452]], [[0.05564670264720917]], [[-0.06821338832378387]], [[-0.1743704229593277]], [[-0.14343582093715668]], [[-0.0228898823261261]], [[-0.1583717167377472]], [[0.06934453547000885]], [[-0.30856001377105713]], [[-0.5183477997779846]], [[0.2440730780363083]], [[0.03383298218250275]], [[-0.09727828949689865]], [[-0.1665688306093216]], [[-0.1382681131362915]], [[-0.13805636763572693]], [[-0.17230486869812012]], [[-0.21519702672958374]], [[-0.11944103240966797]], 
[[-0.5056799054145813]], [[-1.168542742729187]], [[0.0]], [[-0.07803662121295929]], [[-0.1546819806098938]], [[-0.06160996854305267]], [[-0.14303350448608398]], [[-0.07243193686008453]], [[-0.17076128721237183]], [[-0.14328770339488983]], [[-0.22298112511634827]], [[-0.29908308386802673]], [[-0.26053643226623535]], [[-0.13854646682739258]], [[-0.13855187594890594]], [[-0.20685423910617828]], [[-0.23595565557479858]], [[-0.603787362575531]], [[-0.6083805561065674]], [[-0.2908019423484802]], [[-0.31742680072784424]], [[-0.17911767959594727]], [[-0.11052439361810684]], [[-0.2125479280948639]], [[-0.28744760155677795]], [[-0.19026851654052734]], [[-0.2838134169578552]], [[-0.21655139327049255]], [[-0.4087596535682678]], [[-0.47511571645736694]], [[-0.3050718605518341]], [[-1.3514420986175537]], [[-0.7392570972442627]], [[-0.8150429725646973]], [[-0.7708994746208191]], [[-0.6117175817489624]], [[-0.4687038064002991]], [[-1.408411979675293]], [[0.0]]], \"firstDimensionName\": \"Layer\", \"secondDimensionName\": \"Neuron\"}\n",
       "    )\n",
       "    </script>"
      ],
      "text/plain": [
       "<circuitsvis.utils.render.RenderedHTML at 0x7f1d7027d570>"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# N = 500\n",
    "# best_feature = int(max_indices[N])\n",
    "best_feature = 11 # Change this one for global index (N is sorted by MCS)\n",
    "\n",
    "print(f\"Feature index: {best_feature}\")\n",
    "print(f\"MCS: {max_cosine_similarities[best_feature]}\")\n",
    "text_list, full_text, token_list, full_token_list = get_feature_datapoints(best_feature, dictionary_activations, dataset, setting=\"uniform\")\n",
    "# text_list, full_text, token_list, full_token_list = get_feature_datapoints(best_feature, dictionary_activations, dataset, setting=\"max\")\n",
    "visualize_text(text_list, best_feature, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div id=\"circuits-vis-1a23f324-ec24\" style=\"margin: 15px 0;\"/>\n",
       "    <script crossorigin type=\"module\">\n",
       "    import { render, TextNeuronActivations } from \"https://unpkg.com/circuitsvis@1.40.0/dist/cdn/esm.js\";\n",
       "    render(\n",
       "      \"circuits-vis-1a23f324-ec24\",\n",
       "      TextNeuronActivations,\n",
       "      {\"tokens\": [\"Description\", \":\", \" This\", \" project\", \" proposes\", \" two\", \" outreach\", \" programs\", \":\", \" 1\", \")\", \" outreach\", \" to\", \" the\", \" \\\"\", \"Community\", \"\\\"\", \" through\", \" an\", \" effort\", \" with\", \" people\", \" in\", \" Mexico\", \" and\", \" 2\", \")\", \" an\", \" outreach\", \" to\", \" \\\"\", \"\\n\", \"After\", \" acknowledging\", \" that\", \" Santa\", \" Barbara\", \"'s\", \" high\", \"-\", \"density\", \" housing\", \" program\", \" has\", \" failed\", \" in\", \" many\", \" areas\", \",\", \" the\", \" city\", \" is\", \" turning\", \" from\", \" Mil\", \"pas\", \" Street\", \" to\", \" downtown\", \" as\", \" a\", \" place\", \" to\", \" build\", \"\\n\", \"This\", \" is\", \" a\", \" limitation\", \" imposed\", \" by\", \" jet\", \"ty\", \" server\", \" where\", \" the\", \" HTTP\", \" response\", \" header\", \" is\", \" hard\", \" coded\", \" to\", \" 64\", \"kb\", \" in\", \" PH\", \"D\", \" 1\", \".\", \"0\", \".\", \"1\", \" and\", \" 64\", \"kb\", \" is\", \" the\", \"\\n\", \"Republic\", \"ans\", \" are\", \" getting\", \" to\", \" the\", \" rock\", \"-\", \"bottom\", \" of\", \" it\", \".\", \"\\\\newline\", \"\\\\newline\", \"They\", \" have\", \" defended\", \"\\n\", \"I\", \" got\", \" a\", \" wake\", \" up\", \" call\", \",\", \" I\", \" got\", \" to\", \" make\", \" this\", \" work\", \"Cause\", \" if\", \" we\", \" don\", \"\\u00b4\", \"t\", \" we\", \"\\u00b4\", \"re\", \" left\", \" with\", \" nothing\", \" and\", \" that\", \"\\u00b4\", \"\\n\", \"Q\", \":\", \"\\\\newline\", \"\\\\newline\", \"HTTP\", \" Response\", \" in\", \" Android\", \" -\", \" Network\", \"On\", \"Main\", \"Thread\", \"Exception\", \"\\\\newline\", \"\\\\newline\", \"I\", \" want\", \" to\", \" check\", \" the\", \" HTTP\", \" response\", \" of\", \" a\", \" certain\", \" URL\", \" before\", \" loading\", \" into\", \" a\", \" web\", \"view\", \".\", \" I\", \" only\", \" want\", \" to\", 
\" load\", \" web\", \"\\n\", \"Map\", \" Of\", \" London\", \" 32\", \" Borough\", \"s\", \" Ne\", \"ighborhood\", \"s\", \" In\", \"\\\\newline\", \"\\\\newline\", \"Map\", \" Of\", \" London\", \" 32\", \" Borough\", \"s\", \" Ne\", \"ighborhood\", \"s\", \" In\", \"\\\\newline\", \"\\\\newline\", \"Map\", \" Of\", \" London\", \" 32\", \" Borough\", \"\\n\", \"<\", \"header\", \" class\", \"=\\\"\", \"header\", \"-\", \"wrapper\", \"\\\">\", \"\\\\newline\\\\newline\", \"  \", \"<\", \"nav\", \" class\", \"=\\\"\", \"inner\", \"\\\">\", \"\\\\newline\", \"    \", \"<\", \"div\", \" class\", \"=\\\"\", \"title\", \"\\\">\", \"\\\\newline\", \"      \", \"<\", \"a\", \" href\", \"=\\\"/\", \"\\\">\", \"\\\\newline\", \"        \", \"<\", \"img\", \" class\", \"=\\\"\", \"logo\", \"\\n\", \"Poly\", \"(\", \"alkyl\", \"cy\", \"ano\", \"acrylate\", \")\", \" nanoc\", \"aps\", \"ules\", \":\", \" physic\", \"ochemical\", \" characterization\", \" and\", \" mechanism\", \" of\", \" formation\", \".\", \"\\\\newline\", \"Nan\", \"oc\", \"aps\", \"ules\", \" of\", \" poly\", \"(\", \"is\", \"ob\", \"ut\", \"yl\", \"cy\", \"\\n\", \"See\", \",\", \" How\", \" Effective\", \" is\", \" Milk\", \" For\", \" Cold\", \" So\", \"res\", \" Treatment\", \"\\u2026.\", \"\\\\newline\", \"\\\\newline\", \"You\", \" are\", \" definitely\", \" going\", \" to\", \" use\", \" milk\", \" for\", \" cold\", \" so\", \"res\", \" after\", \" learning\", \" the\", \" awesome\", \" benefits\", \" of\", \" milk\", \" for\", \" cold\", \" so\", \"\\n\", \"Q\", \":\", \"\\\\newline\", \"\\\\newline\", \"How\", \" to\", \" say\", \" \\\"\", \"I\", \" look\", \"\\\"\", \" as\", \" in\", \" \\\"\", \"I\", \" look\", \" stupid\", \"\\\"\", \" in\", \" French\", \"?\", \"\\\\newline\", \"\\\\newline\", \"How\", \" do\", \" you\", \" say\", \" \\\"\", \"I\", \" look\", \"\\\"\", \" as\", \" in\", \" \\\"\", \"I\", \"\\n\"], \"activations\": [[[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], 
[[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], 
[[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.2615354061126709]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]]], \"firstDimensionName\": \"Layer\", \"secondDimensionName\": \"Neuron\"}\n",
       "    )\n",
       "    </script>"
      ],
      "text/plain": [
       "<circuitsvis.utils.render.RenderedHTML at 0x7f1d459335b0>"
      ]
     },
     "execution_count": 77,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ablate_text(text_list, best_feature, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div id=\"circuits-vis-b076d65d-ad2f\" style=\"margin: 15px 0;\"/>\n",
       "    <script crossorigin type=\"module\">\n",
       "    import { render, TextNeuronActivations } from \"https://unpkg.com/circuitsvis@1.40.0/dist/cdn/esm.js\";\n",
       "    render(\n",
       "      \"circuits-vis-b076d65d-ad2f\",\n",
       "      TextNeuronActivations,\n",
       "      {\"tokens\": [\":\", \" This\", \" project\", \" proposes\", \" two\", \" outreach\", \" programs\", \":\", \" 1\", \")\", \" outreach\", \" to\", \" the\", \" \\\"\", \"Community\", \"\\\"\", \" through\", \" an\", \" effort\", \" with\", \" people\", \" in\", \" Mexico\", \" and\", \" 2\", \")\", \" an\", \" outreach\", \" to\", \" \\\"\", \"Indust\", \"ry\", \"\\\"\", \" through\", \" an\", \" effort\", \" in\", \" economic\", \" development\", \"\\n\", \" acknowledging\", \" that\", \" Santa\", \" Barbara\", \"'s\", \" high\", \"-\", \"density\", \" housing\", \" program\", \" has\", \" failed\", \" in\", \" many\", \" areas\", \",\", \" the\", \" city\", \" is\", \" turning\", \" from\", \" Mil\", \"pas\", \" Street\", \" to\", \" downtown\", \" as\", \" a\", \" place\", \" to\", \" build\", \" rental\", \" apartments\", \".\", \"\\\\newline\", \"\\\\newline\", \"Santa\", \" Barbara\", \"'s\", \"\\n\", \" is\", \" a\", \" limitation\", \" imposed\", \" by\", \" jet\", \"ty\", \" server\", \" where\", \" the\", \" HTTP\", \" response\", \" header\", \" is\", \" hard\", \" coded\", \" to\", \" 64\", \"kb\", \" in\", \" PH\", \"D\", \" 1\", \".\", \"0\", \".\", \"1\", \" and\", \" 64\", \"kb\", \" is\", \" the\", \" default\", \" setting\", \" for\", \" PH\", \"D\", \" 1\", \".\", \"\\n\", \"ans\", \" are\", \" getting\", \" to\", \" the\", \" rock\", \"-\", \"bottom\", \" of\", \" it\", \".\", \"\\\\newline\", \"\\\\newline\", \"They\", \" have\", \" defended\", \" President\", \" Trump\", \" throughout\", \" the\", \" public\", \" impeachment\", \" hearings\", \" by\", \" arguing\", \" his\", \" gang\", \"ster\", \" efforts\", \" to\", \" force\", \" a\", \" Ukrainian\", \" investigation\", \" into\", \" its\", \" (\", \"imag\", \"ined\", \"\\n\", \" got\", \" a\", \" wake\", \" up\", \" call\", \",\", \" I\", \" got\", \" to\", \" make\", \" this\", \" work\", \"Cause\", \" if\", \" we\", \" don\", \"\\u00b4\", \"t\", \" we\", \"\\u00b4\", 
\"re\", \" left\", \" with\", \" nothing\", \" and\", \" that\", \"\\u00b4\", \"s\", \" what\", \" hurts\", \"We\", \"\\u00b4\", \"re\", \" so\", \" close\", \" to\", \" giving\", \" up\", \" but\", \"\\n\", \":\", \"\\\\newline\", \"\\\\newline\", \"HTTP\", \" Response\", \" in\", \" Android\", \" -\", \" Network\", \"On\", \"Main\", \"Thread\", \"Exception\", \"\\\\newline\", \"\\\\newline\", \"I\", \" want\", \" to\", \" check\", \" the\", \" HTTP\", \" response\", \" of\", \" a\", \" certain\", \" URL\", \" before\", \" loading\", \" into\", \" a\", \" web\", \"view\", \".\", \" I\", \" only\", \" want\", \" to\", \" load\", \" web\", \"\\n\", \" Of\", \" London\", \" 32\", \" Borough\", \"s\", \" Ne\", \"ighborhood\", \"s\", \" In\", \"\\\\newline\", \"\\\\newline\", \"Map\", \" Of\", \" London\", \" 32\", \" Borough\", \"s\", \" Ne\", \"ighborhood\", \"s\", \" In\", \"\\\\newline\", \"\\\\newline\", \"Map\", \" Of\", \" London\", \" 32\", \" Borough\", \"s\", \" Ne\", \"ighborhood\", \"s\", \" In\", \" is\", \" one\", \" from\", \" many\", \" image\", \" from\", \"\\n\", \"header\", \" class\", \"=\\\"\", \"header\", \"-\", \"wrapper\", \"\\\">\", \"\\\\newline\\\\newline\", \"  \", \"<\", \"nav\", \" class\", \"=\\\"\", \"inner\", \"\\\">\", \"\\\\newline\", \"    \", \"<\", \"div\", \" class\", \"=\\\"\", \"title\", \"\\\">\", \"\\\\newline\", \"      \", \"<\", \"a\", \" href\", \"=\\\"/\", \"\\\">\", \"\\\\newline\", \"        \", \"<\", \"img\", \" class\", \"=\\\"\", \"logo\", \"\\\"\", \" src\", \"\\n\", \"(\", \"alkyl\", \"cy\", \"ano\", \"acrylate\", \")\", \" nanoc\", \"aps\", \"ules\", \":\", \" physic\", \"ochemical\", \" characterization\", \" and\", \" mechanism\", \" of\", \" formation\", \".\", \"\\\\newline\", \"Nan\", \"oc\", \"aps\", \"ules\", \" of\", \" poly\", \"(\", \"is\", \"ob\", \"ut\", \"yl\", \"cy\", \"ano\", \"acrylate\", \")\", \" and\", \" poly\", \"(\", \"iso\", \"hex\", \"\\n\", \",\", \" How\", \" Effective\", \" is\", \" 
Milk\", \" For\", \" Cold\", \" So\", \"res\", \" Treatment\", \"\\u2026.\", \"\\\\newline\", \"\\\\newline\", \"You\", \" are\", \" definitely\", \" going\", \" to\", \" use\", \" milk\", \" for\", \" cold\", \" so\", \"res\", \" after\", \" learning\", \" the\", \" awesome\", \" benefits\", \" of\", \" milk\", \" for\", \" cold\", \" so\", \"res\", \" you\", \" will\", \" see\", \" in\", \"\\n\", \":\", \"\\\\newline\", \"\\\\newline\", \"How\", \" to\", \" say\", \" \\\"\", \"I\", \" look\", \"\\\"\", \" as\", \" in\", \" \\\"\", \"I\", \" look\", \" stupid\", \"\\\"\", \" in\", \" French\", \"?\", \"\\\\newline\", \"\\\\newline\", \"How\", \" do\", \" you\", \" say\", \" \\\"\", \"I\", \" look\", \"\\\"\", \" as\", \" in\", \" \\\"\", \"I\", \" look\", \" pretty\", \"/\", \"st\", \"upid\", \"\\n\"], \"activations\": [[[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], 
[[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], 
[[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]], [[1.1175870895385742e-07]], [[0.0]], [[0.0]], [[-1.4007091522216797e-06]], [[0.0]], [[0.0]], [[7.152557373046875e-06]], [[0.0]], [[0.0]], [[5.245208740234375e-06]], [[8.58306884765625e-05]], [[2.2172927856445312e-05]], [[-1.0609626770019531e-05]], [[0.0]], [[0.0]], [[0.0]], [[-2.193450927734375e-05]], [[-5.1975250244140625e-05]], [[-5.781650543212891e-06]], [[1.0579824447631836e-05]], [[8.717179298400879e-06]], [[0.00015497207641601562]], [[1.1920928955078125e-07]], [[-4.184246063232422e-05]], [[8.106231689453125e-06]], [[2.1457672119140625e-06]], [[-1.1324882507324219e-06]], [[0.0]], [[0.0]], [[-2.980232238769531e-07]], [[-3.5762786865234375e-06]], [[-1.0728836059570312e-06]], [[4.32133674621582e-06]], [[1.7881393432617188e-07]], [[-0.005278587341308594]], [[-0.011476516723632812]], [[-0.004693031311035156]], [[-1.5437602996826172e-05]], [[0.0]]], \"firstDimensionName\": \"Layer\", \"secondDimensionName\": \"Neuron\"}\n",
       "    )\n",
       "    </script>"
      ],
      "text/plain": [
       "<circuitsvis.utils.render.RenderedHTML at 0x7f1d7027c400>"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Presumably re-runs `full_text` with `best_feature`'s direction ablated and renders the\n",
    "# effect (the saved output is a circuitsvis activation view); helper defined elsewhere.\n",
    "ablate_feature_direction_display(full_text, best_feature)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logit-lens view of `best_feature` (helper `logit_lens` is defined elsewhere in the notebook).\n",
    "logit_lens(\n",
    "    model,\n",
    "    best_feature,\n",
    "    smaller_dict,\n",
    "    layer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-written probe: powers of two up to 32768, then an irregular run of numbers.\n",
    "counting_probe = \"I can count up to: 2 4 8 16 32 64 128 256 512 1024 2048 4096 8192 16384 32768 7 6 12 16 18 20 22 24\"\n",
    "custom_text = [counting_probe]\n",
    "visualize_text(custom_text, best_feature, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data-Centric Viewpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Go through datapoints & check whether the features that activate on them make sense.\n",
    "# `d_point` is a flat index over the (datapoints, token_amount) grid of token positions.\n",
    "d_point = 0\n",
    "data_ind, sequence_pos = (int(i) for i in np.unravel_index(d_point, (datapoints, token_amount)))\n",
    "# The 10 largest dictionary activations at that flat position.\n",
    "feature_val, feature_ind = dictionary_activations[d_point].topk(10)\n",
    "full_tok = torch.tensor(dataset[data_ind][\"input_ids\"])\n",
    "full_text = [model.tokenizer.decode(full_tok)]\n",
    "visualize_text(full_text, feature_ind, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check the neuron/residual basis\n",
    "When we look at the weights of a feature, we see which dimensions of the residual stream/neurons the feature reads from.\n",
    "\n",
    "Here I list the largest and smallest weight values (scaled by the feature's max activation). If a few dimensions are outliers, the feature mainly reads from those dimensions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scale `weights` by `max_activation` once, then show the 20 largest entries (values and\n",
    "# indices), the 20 smallest values, and how many entries exceed 0.2.\n",
    "scaled_weights = weights * max_activation\n",
    "(\n",
    "    scaled_weights.topk(20),\n",
    "    scaled_weights.topk(20, largest=False).values,\n",
    "    (scaled_weights > 0.2).sum(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepend/Append tokens\n",
    "To test a hypothesis about what a feature means more rigorously, we can iterate over every token in the vocabulary and check which ones activate the feature most strongly.\n",
    "\n",
    "Note: this literally runs the model on every vocabulary token (~50k), each prepended (or appended) to the text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepend_all_tokens_and_get_feature_activation(\n",
    "    model,\n",
    "    minimal_activating_example,\n",
    "    feature,\n",
    "    setting=\"prepend\",\n",
    "    k=20,\n",
    "    batch_size=512,\n",
    "    autoencoder=None,\n",
    "    hook_name=None,\n",
    "):\n",
    "    \"\"\"Try every vocabulary token before/after a text and rank its effect on one feature.\n",
    "\n",
    "    Each token id in ``range(model.cfg.d_vocab)`` is prepended (``setting=\"prepend\"``)\n",
    "    or appended (``setting=\"append\"``) to ``minimal_activating_example`` (tokenized\n",
    "    without BOS). The activation stored under ``hook_name`` is encoded with\n",
    "    ``autoencoder`` and the value of ``feature`` at the final position is recorded\n",
    "    for that token. The ``k`` highest/lowest tokens and their values are printed.\n",
    "\n",
    "    Args:\n",
    "        model: model exposing ``to_tokens``, ``run_with_cache``, ``to_str_tokens`` and\n",
    "            ``cfg.d_vocab`` (presumably a TransformerLens HookedTransformer).\n",
    "        minimal_activating_example: text to extend; may be the empty string.\n",
    "        feature: index of the dictionary feature to read.\n",
    "        setting: ``\"prepend\"`` or ``\"append\"``.\n",
    "        k: number of top and bottom tokens to print (default 20).\n",
    "        batch_size: candidate tokens per forward pass (default 512).\n",
    "        autoencoder: callable returning a pair whose second element is indexed as\n",
    "            ``codes[:, position, feature]``; defaults to the notebook's\n",
    "            ``smaller_auto_encoder``.\n",
    "        hook_name: cache key of the activation to encode; defaults to the\n",
    "            notebook's ``cache_name``.\n",
    "\n",
    "    Returns:\n",
    "        ``minimal_activating_example`` extended by the single highest-activating\n",
    "        token. It is re-tokenized on the next call, which may not reproduce the\n",
    "        same token boundaries.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``setting`` is neither ``\"prepend\"`` nor ``\"append\"``.\n",
    "    \"\"\"\n",
    "    # Validate up front so a typo fails before a full-vocabulary sweep starts.\n",
    "    if setting not in (\"prepend\", \"append\"):\n",
    "        raise ValueError(f\"Unknown setting: {setting}\")\n",
    "    # Default to the notebook-level globals the original version read implicitly.\n",
    "    if autoencoder is None:\n",
    "        autoencoder = smaller_auto_encoder\n",
    "    if hook_name is None:\n",
    "        hook_name = cache_name\n",
    "    # Cache only the hook we read instead of every activation in the model.\n",
    "    names_filter = hook_name if isinstance(hook_name, str) else None\n",
    "\n",
    "    tokens = model.to_tokens(minimal_activating_example, prepend_bos=False)\n",
    "    vocab_size = model.cfg.d_vocab\n",
    "    feature_activations = torch.zeros(vocab_size)\n",
    "\n",
    "    for start in range(0, vocab_size, batch_size):\n",
    "        end = min(start + batch_size, vocab_size)\n",
    "        # One candidate token per row, created on the same device as `tokens` so torch.cat works.\n",
    "        candidates = torch.arange(start, end, device=tokens.device).unsqueeze(1)\n",
    "        context = tokens.repeat(end - start, 1)\n",
    "        pieces = (candidates, context) if setting == \"prepend\" else (context, candidates)\n",
    "        # .long() kept from the original; presumably guards the empty-text case, where the\n",
    "        # tokenizer output may not be an integer tensor.\n",
    "        tokens_catted = torch.cat(pieces, dim=1).long()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            _, cache = model.run_with_cache(tokens_catted, names_filter=names_filter)\n",
    "            _, act = autoencoder(cache[hook_name])\n",
    "\n",
    "        # Feature value at the final sequence position for each candidate token.\n",
    "        feature_activations[start:end] = act[:, -1, feature].cpu().squeeze()\n",
    "\n",
    "    k_increasing_val, k_increasing_ind = feature_activations.topk(k)\n",
    "    k_decreasing_val, k_decreasing_ind = feature_activations.topk(k, largest=False)\n",
    "    if setting == \"prepend\":\n",
    "        print(f\"[token]{minimal_activating_example}\")\n",
    "    else:\n",
    "        print(f\"{minimal_activating_example}[token]\")\n",
    "    # Print indices converted to tokens, then the corresponding values.\n",
    "    print(f\"Top-{k} increasing: {model.to_str_tokens(k_increasing_ind)}\")\n",
    "    print(f\"Top-{k} increasing: {[f'{val:.2f}' for val in k_increasing_val]}\")\n",
    "    print(f\"Top-{k} decreasing: {model.to_str_tokens(k_decreasing_ind)}\")\n",
    "    print(f\"Top-{k} decreasing: {[f'{val:.2f}' for val in k_decreasing_val]}\")\n",
    "    print(f\"Number of 0 activations: {torch.sum(feature_activations == 0)}\")\n",
    "\n",
    "    best_token = model.to_str_tokens(feature_activations.argmax())\n",
    "    if setting == \"prepend\":\n",
    "        return \"\".join(best_token + [minimal_activating_example])\n",
    "    return \"\".join([minimal_activating_example] + best_token)\n",
    "\n",
    "# Greedily grow a context rightwards: each round appends the single vocabulary token\n",
    "# that most activates `best_feature` at the final position (3 rounds).\n",
    "best_text = \"\"\n",
    "for _round in range(3):\n",
    "    best_text = prepend_all_tokens_and_get_feature_activation(\n",
    "        model, best_text, best_feature, setting=\"append\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Greedily grow a context leftwards: each round prepends the single vocabulary token\n",
    "# that most activates `best_feature` at the final position (3 rounds).\n",
    "best_text = \"\"\n",
    "for _round in range(3):\n",
    "    best_text = prepend_all_tokens_and_get_feature_activation(\n",
    "        model, best_text, best_feature, setting=\"prepend\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which single token, placed before \" for all $\", most activates `best_feature`?\n",
    "prepend_all_tokens_and_get_feature_activation(\n",
    "    model, \" for all $\", best_feature, setting=\"prepend\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16 (main, Dec  7 2022, 10:02:13) \n[Clang 14.0.0 (clang-1400.0.29.202)]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "2313d0e6d125794df61c1d508f7cadabba0eaf461173e1d387638b82cb721b48"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
